import torch
import torch.nn as nn

# Random inputs that share a batch size (3) and a feature size (5).
As = torch.randn(3, 2, 5)  # (batch=3, rows=2, features=5)
Bs = torch.randn(3, 4, 5)  # (batch=3, rows=4, features=5)

# Project the last (feature) dimension of As from 5 to 10.
# nn.Linear's third positional parameter is `bias` (a bool), not a size,
# so it is passed by keyword to make the intent unambiguous.
A2B_linear = nn.Linear(5, 10, bias=True)
As__ = A2B_linear(As)  # Linear acts on the last dim: (3, 2, 5) -> (3, 2, 10)
print(As__.size())

# Batched A @ B^T: for each batch b, contract the shared feature axis k.
# Equivalent to torch.bmm(As, Bs.transpose(1, 2)); result shape is (3, 2, 4).
AB = torch.einsum('bik,bjk->bij', As, Bs)
print(AB.size())
print(AB)
